---
title: "Analysis TDM with Lasso"
author: "Nicolas Banholzer"
date: "13/01/2019"
output: html_document
---

```{r setup, include=FALSE}
# Hide code, warnings and messages in the rendered report by default.
knitr::opts_chunk$set(message = FALSE, warning = FALSE, echo = FALSE)

# Evaluate `x` (lazily) with every warning and message muffled; returns the
# value of `x`.
suppress_all <- function(x) {
  suppressMessages(suppressWarnings(x))
}
```

## Setup 

### Libraries

```{r, include=FALSE}
library(caret)
library(glmnet)
library(textcat)
library(tidyverse)
library(tm)
source("helper.R")
library(tikzDevice)
```

### Load Data

```{r}
# Join the hand-labelled classification sheet with the EP and US full-text
# abstract sheets. The EP abstract is used when present, otherwise the US one.
# The free-text classification is collapsed to "product" / "process"; any
# other value becomes NA and, like rows without text or keywords, is dropped
# by na.omit().
df <- local({
  xlsx <- "../Data/patents_sample_final.xlsx"
  read_sheet <- function(sheet) readxl::read_excel(xlsx, sheet = sheet)
  key <- c("patent_number" = "Patent number")

  read_sheet("classification_abstracts") %>%
    left_join(y = read_sheet("Full-text abstracts_EP"), by = key) %>%
    left_join(y = read_sheet("Full-text abstracts_US"), by = key) %>%
    dplyr::rename(label = classification,
                  text1 = Abstract_text.x,
                  text2 = Abstract_text.y) %>%
    mutate(label = tolower(label),
           text = ifelse(is.na(text1), text2, text1)) %>%
    mutate(label = ifelse(grepl("product", label), "product",
                          ifelse(grepl("process$", label), "process", NA))) %>%
    dplyr::select(patent_number, label, keywords, text) %>%
    na.omit()
})
```


## Descriptives

```{r}
# Class balance. `label` is still character here, and summary() of a
# character vector only reports Length/Class/Mode, so convert to a factor to
# get per-class counts.
summary(as.factor(df$label))
```

```{r}
# Number of abstracts per language detected by textcat.
df$text %>%
  textcat() %>%
  as.factor() %>%
  summary()
```

```{r}
# Distribution of abstract lengths in characters. tibble() replaces the
# deprecated data_frame().
ggplot(data = tibble(n_chars = nchar(df$text)), aes(x = n_chars)) +
  geom_histogram() +
  labs(x = "Abstract length (characters)", y = "Count")
```

```{r}
# Number of abstracts whose text repeats an earlier row's text
df$text %>% duplicated() %>% sum()
```


## Stratification

```{r}
# For every distinct abstract text, count how many rows share it and how those
# rows are labelled, so conflicting duplicates can be resolved afterwards.
# Grouping on `text` directly replaces the deprecated group_indices(., text)
# plus temporary id column. Labels are exactly "product"/"process" at this
# point.
df <- df %>%
  group_by(text) %>%
  mutate(n = n(),
         n_product = sum(label == "product"),
         n_process = sum(label == "process")) %>%
  ungroup()

# Show the rows whose text occurs more than once (text omitted for brevity).
df %>%
  dplyr::filter(n > 1) %>%
  dplyr::select(-text)
```

```{r, echo=TRUE}
# Resolve conflicting labels among duplicated texts by majority vote (ties go
# to "product"), keep only abstracts textcat classifies as English, keep the
# first row per text, and store the label as a factor.
df <- df %>%
  mutate(label = ifelse(n_product >= n_process, "product", "process")) %>%
  dplyr::filter(grepl("english", textcat(text))) %>%
  dplyr::distinct(text, .keep_all = TRUE) %>%
  dplyr::select(patent_number, keywords, label, text) %>%
  mutate(label = as.factor(label))
```

After resolving conflicting labels, keeping only English abstracts and removing duplicates, `r nrow(df)` observations are left.

```{r}
# Persist the cleaned, labelled sample as UTF-8 CSV without row names.
write.csv(x = df,
          file = "../Data/patents_sample_final_preprocessed_abstracts.csv",
          fileEncoding = "utf-8",
          row.names = FALSE)
```

## Document-term matrix (TDM)

```{r}
# Build a document-term matrix from raw text: one row per document, one column
# per stemmed term, raw counts (no inverse weighting). Whitespace runs,
# punctuation, numbers and English stop words are removed, text is
# lower-cased and stemmed, and terms shorter than 4 characters are dropped.
#
# text    : character vector, one element per document.
# ref_tdm : optional reference matrix (e.g. the training DTM). When given, the
#           result has exactly ref_tdm's columns in ref_tdm's order. Terms not
#           in ref_tdm are dropped; reference terms absent from `text` are 0.
# sparse  : maximal sparsity passed to removeSparseTerms() when a fresh
#           vocabulary is built (ref_tdm = NULL). It is deliberately NOT
#           applied when aligning to ref_tdm: pruning against the new corpus
#           would silently zero reference terms that are merely rare in it.
# Returns a numeric matrix.
gen_tdm <- function(text, ref_tdm = NULL, sparse = 0.95) {
  dtm <- suppress_all(
    Corpus(VectorSource(text), readerControl = list(language = "english")) %>%
      tm_map(PlainTextDocument) %>%
      tm_map(stripWhitespace) %>%
      tm_map(removePunctuation) %>%
      tm_map(content_transformer(tolower)) %>%
      tm_map(removeNumbers) %>%
      tm_map(removeWords, stopwords("english")) %>%
      tm_map(stemDocument, language = "english") %>%
      # tm_map(lemmatize_strings) %>%
      DocumentTermMatrix(control = list(wordLengths = c(4, Inf)))
  )

  if (is.null(ref_tdm)) {
    return(as.matrix(removeSparseTerms(dtm, sparse)))
  }

  ref_terms <- colnames(ref_tdm)
  aligned <- matrix(0, nrow = nDocs(dtm), ncol = length(ref_terms),
                    dimnames = list(NULL, ref_terms))
  keep <- Terms(dtm) %in% ref_terms
  if (any(keep)) {
    # Densify only the shared terms. Without sparsity pruning, a large batch's
    # full vocabulary can be far wider than the reference.
    aligned[, Terms(dtm)[keep]] <- as.matrix(dtm[, keep])
  }
  aligned
}

# Document-term matrix of the labelled sample. Its columns are the vocabulary
# later batches are aligned to (ref_tdm).
tdm_mat <- gen_tdm(text = df$text)
print(colnames(tdm_mat))
```

## Model

```{r}
set.seed(seed = 0)  # reproducible split and resampling in the chunks below
```

```{r}
# 70/30 split stratified on the label (createDataPartition samples within
# each class of a factor outcome).
tid <- createDataPartition(df$label, times = 1, p = .7, list = FALSE)[, 1]
y_train <- df$label[tid]
x_train <- tdm_mat[tid, ]
y_test <- df$label[-tid]
x_test <- tdm_mat[-tid, ]
```

```{r}
# Elastic-net logistic regression (alpha = 1 is the lasso), tuned by 10-fold
# CV repeated 5 times and selected on accuracy. The control object no longer
# shadows caret's trainControl() function name.
# NOTE(review): lambdas far above the largest useful value all give the
# intercept-only model; a shorter log-spaced grid would be much cheaper.
train_ctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 5)
tune_grid <- expand.grid(.alpha = c(.1, .5, .7, .9, .95, .99, 1),
                         .lambda = c(seq(0, 10, by = 0.1), seq(11, 1000, by = 1)))
glmnet_bi <- train(x = x_train, y = y_train, method = "glmnet",
                   family = "binomial", metric = "Accuracy",
                   trControl = train_ctrl, tuneGrid = tune_grid)
```

```{r, echo=TRUE}
best_alpha <- glmnet_bi[["bestTune"]][["alpha"]]  # CV-selected mixing weight
print(best_alpha)
```

```{r, echo=TRUE}
best_lambda <- glmnet_bi[["bestTune"]][["lambda"]]  # CV-selected penalty
print(best_lambda)
```

```{r, echo=TRUE}
coef(object = glmnet_bi$finalModel, s = best_lambda)  # coefficients at the chosen penalty
```

## Prediction

```{r, echo=TRUE}
# Recode labels as 1 = "product", 0 = "process", then score both splits with
# the predicted probability of "product". The column is selected by name
# rather than position, so it stays correct if the factor levels change.
y_train <- as.numeric(y_train == "product")
y_test <- as.numeric(y_test == "product")
yhat_train <- predict(glmnet_bi, x_train, type = "prob")[, "product"]
yhat_test <- predict(glmnet_bi, x_test, type = "prob")[, "product"]

get.auc(y_test, yhat_test)
```

```{r}
pl_roc <- plot_roc(name = "Abstracts", response = y_test, predictor = yhat_test,
                   show_top_left = TRUE)
pl_roc

# Export a 7 cm x 7 cm TikZ version. tikz() takes inches and cm(1) is the
# number of cm per inch. print() is explicit so the plot is drawn on the tikz
# device rather than relying on auto-printing; dev.off()'s return value is
# hidden from the report.
tikz("../Doc/abstracts_roc.tex", width = 7 / cm(1), height = 7 / cm(1))
print(pl_roc)
invisible(dev.off())
```

```{r}
# Classification cut-offs from get.optimal_cutoff().
# The cut-off applied to the test set (and later to unseen abstracts) must not
# be tuned on test labels, otherwise test-set metrics are optimistically
# biased. t_best_test therefore reuses the cut-off chosen on training data.
# The test-optimal ("oracle") cut-off is kept for comparison only.
t_best_train <- get.optimal_cutoff(y_train, yhat_train)
t_oracle_test <- get.optimal_cutoff(y_test, yhat_test)
t_best_test <- t_best_train
```

```{r}
# Hard 0/1 predictions: 1 ("product") when the probability reaches the cut-off.
yhat_train_label <- as.numeric(yhat_train >= t_best_train)
yhat_test_label <- as.numeric(yhat_test >= t_best_test)
```

```{r}
# Per-observation predictions for each split; `id` is the row index in df.
# tibble() replaces the deprecated data_frame(); seq_len() replaces 1:nrow().
results_train <- tibble(id = seq_len(nrow(df))[tid],
                        text = df$text[tid],
                        true_label = y_train,
                        best_lambda = best_lambda,
                        best_alpha = best_alpha,
                        best_thresh = t_best_train,
                        pred_label = yhat_train_label,
                        pred_prob = yhat_train)
results_test <- tibble(id = seq_len(nrow(df))[-tid],
                       text = df$text[-tid],
                       true_label = y_test,
                       best_thresh = t_best_test,
                       pred_label = yhat_test_label,
                       pred_prob = yhat_test)
write.csv(file = "results_glmnet_abstracts_train.csv",
          x = results_train,
          row.names = FALSE,
          fileEncoding = "utf-8")
write.csv(file = "results_glmnet_abstracts_test.csv",
          x = results_test,
          row.names = FALSE,
          fileEncoding = "utf-8")
```

```{r}
# Misclassified test abstracts, most confidently wrong first. true_label is
# already coded 0/1 in results_test, so the loss is simply the distance
# between the true label and the predicted "product" probability.
false_preds <- results_test %>%
  mutate(test_loss = abs(true_label - pred_prob)) %>%
  dplyr::filter(true_label != pred_label) %>%
  arrange(desc(test_loss))

head(false_preds$text, 10)
```

## Prediction for all abstracts

```{r}
# Score every abstract batch file with the tuned model and write the
# predictions back into the same file. The columns pred_prob and pred_label
# are NA for non-English abstracts.
# A batch with fewer than `min_english` English abstracts is carried over and
# merged into the next file. A merged file is deleted only after the batch
# containing its rows has been saved. The final batch is always scored, so no
# rows are lost.
abstracts <- list.files("../Data/AbstractBatches", full.names = TRUE)
start_batch <- 1393L  # resume point of an earlier run; 1L processes every file
min_english <- 1000L
todo <- abstracts[seq_along(abstracts) >= start_batch]

A_skip <- NULL                   # rows carried over from undersized batches
pending_removal <- character(0)  # files whose rows are currently in A_skip
for (i in seq_along(todo)) {
  a <- todo[[i]]
  print(sprintf("Reading batch %s", a))
  A <- read_csv(a)
  if (!is.null(A_skip)) {
    A <- rbind(A_skip, A)
  }
  en_rows <- which(A$appln_abstract_lg == "en")

  if (length(en_rows) < min_english && i < length(todo)) {
    A_skip <- A
    pending_removal <- c(pending_removal, a)
    print(sprintf("Appending file %s to next batch", a))
    next
  }

  A$pred_prob <- NA_real_
  A$pred_label <- NA_real_
  if (length(en_rows) > 0) {
    print("Generating TDM")
    M <- gen_tdm(A$appln_abstract[en_rows], ref_tdm = tdm_mat)
    print("Predicting label")
    p <- predict(glmnet_bi, M, type = "prob")[, "product"]
    A$pred_prob[en_rows] <- p
    A$pred_label[en_rows] <- as.numeric(p >= t_best_test)
  }
  print(sprintf("Saving file %s with prediction", a))
  write.csv(A, a, row.names = FALSE)

  # Rows of the merged files are now persisted in `a`, so removing them is safe.
  for (merged in pending_removal) {
    print(sprintf("Removing file %s (rows saved in %s)", merged, a))
    file.remove(merged)
  }
  A_skip <- NULL
  pending_removal <- character(0)
}
```
